# (C) Crown Copyright, Met Office. All rights reserved.
#
# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license.
# See LICENSE in the root of the repository for full licensing details.
"""Module containing percentiling classes."""

from typing import List, Optional, Union

import iris
import numpy as np
from iris.cube import Cube
from iris.exceptions import CoordinateNotFoundError

from improver import BasePlugin
from improver.constants import DEFAULT_PERCENTILES
from improver.metadata.constants.time_types import TIME_COORDS
from improver.metadata.probabilistic import find_percentile_coordinate
from improver.metadata.utilities import enforce_time_point_standard
from improver.utilities.cube_manipulation import collapsed


class PercentileConverter(BasePlugin):
    """Plugin for converting from a set of values to a PDF.

    The values along one or more collapse coordinates are treated as samples
    of a distribution and the requested percentiles of that distribution are
    calculated. On the output cube the collapsed coordinate(s) are replaced by
    a percentile coordinate named "percentile" with units of "%".
    """

    def __init__(
        self,
        collapse_coord: Union[str, List[str]],
        percentiles: Optional[List[float]] = None,
        retained_coordinates: Optional[Union[str, List[str]]] = None,
        fast_percentile_method: bool = True,
    ) -> None:
        """
        Configure which coordinate(s) to collapse and which percentiles to
        calculate.

        Args:
            collapse_coord:
                The name of the coordinate(s) to collapse over. This
                coordinate(s) will no longer be present on the output cube, as
                it will have been replaced by the percentile coordinate.
            percentiles:
                Percentile values at which to calculate; if not provided uses
                DEFAULT_PERCENTILES. (optional)
            retained_coordinates:
                Optional coordinate name, or list of names, of collapsed
                coordinates that should be retained in their new scalar form.
                The default behaviour is to remove the scalar coordinates that
                result from coordinate collapse.
            fast_percentile_method:
                If True use the numpy percentile method within Iris, which is
                much faster than scipy, but cannot handle masked data.

        Raises:
            TypeError: If collapse_coord is not a string or a list of strings.
        """
        if not isinstance(collapse_coord, list):
            collapse_coord = [collapse_coord]
        if not all(isinstance(test_coord, str) for test_coord in collapse_coord):
            raise TypeError(
                "collapse_coord is {!r}, which is not a string "
                "as is expected.".format(collapse_coord)
            )

        if percentiles is None:
            percentiles = DEFAULT_PERCENTILES
        self.percentiles = [np.float32(value) for value in percentiles]

        # Collapsing multiple coordinates results in a new percentile
        # coordinate, its name suffixed by the original coordinate names. Such
        # a collapse is commutative (i.e. coordinate order doesn't matter).
        # However the coordinates are sorted here such that the resulting
        # percentile coordinate has a consistent name regardless of the order
        # in which the user provides the original coordinate names.
        self.collapse_coord = sorted(collapse_coord)

        # A single name given as a plain string is wrapped in a list so that
        # the membership test in process() compares whole names. Testing
        # membership against a str is a substring test, so e.g. "time" would
        # wrongly count as retained if "forecast_reference_time" were given.
        if isinstance(retained_coordinates, str):
            retained_coordinates = [retained_coordinates]
        self.retained_coordinates = retained_coordinates
        self.fast_percentile_method = fast_percentile_method

    def __repr__(self) -> str:
        """Represent the configured plugin instance as a string."""
        return "<PercentileConverter: collapse_coord={}, percentiles={}>".format(
            self.collapse_coord, self.percentiles
        )

    def process(self, cube: Cube) -> Cube:
        """
        Create a cube containing the percentiles as a new dimension.

        The percentiles calculated are those given at initialisation, or
        DEFAULT_PERCENTILES when none were given.

        Args:
            cube:
                Given the collapse coordinate, convert the set of values
                along that coordinate into a PDF and extract percentiles.

        Returns:
            A single merged cube of all the cubes produced by each
            percentile collapse. Its data has the same dtype as the input
            cube, and its percentile coordinate is named "percentile" with
            units of "%".

        Raises:
            CoordinateNotFoundError: If any of the collapse coordinates is
                not present on the input cube.
        """
        # Test that every collapse coord is present in the cube before
        # proceeding. A set-membership test is used (rather than counting
        # matches) so that coords sharing a name cannot mask a missing one.
        available_names = {coord.name() for coord in cube.coords()}
        if not all(name in available_names for name in self.collapse_coord):
            raise CoordinateNotFoundError(
                "Coordinate '{}' not found in cube passed to {}.".format(
                    self.collapse_coord, self.__class__.__name__
                )
            )

        # Store data type and enforce the same type on return.
        data_type = cube.dtype
        result = collapsed(
            cube,
            self.collapse_coord,
            iris.analysis.PERCENTILE,
            percent=self.percentiles,
            fast_percentile_method=self.fast_percentile_method,
        )
        result.data = result.data.astype(data_type)

        # Remove the scalar coordinates left behind by the collapse, except
        # those the caller asked to keep.
        retained = self.retained_coordinates or []
        for coord in self.collapse_coord:
            if coord not in retained:
                result.remove_coord(coord)

        # If a time related coordinate has been collapsed we need to
        # enforce the IMPROVER standard of a coordinate point that aligns
        # with the upper bound of the period.
        if any(crd in TIME_COORDS for crd in self.collapse_coord):
            enforce_time_point_standard(result)

        # Rename the percentile coordinate to "percentile" and also
        # make sure that the associated unit is %.
        percentile_coord = result.coord(find_percentile_coordinate(result))
        percentile_coord.rename("percentile")
        percentile_coord.units = "%"
        return result
